
package org.apache.solr.request;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Future;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldCache;
import org.apache.lucene.search.Filter;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.util.UnicodeUtil;
import org.apache.lucene.util.packed.PackedInts;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.FacetParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.schema.FieldType;
import org.apache.solr.search.DocSet;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.BoundedTreeSet;

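/**
 * Computes facet counts for a single-valued field by counting each index
 * segment independently (optionally in parallel) against the FieldCache ords,
 * then merging the per-segment counts with a priority queue ordered by term.
 *
 * A sketch of typical usage; the searcher, doc set, executor, and parameter
 * values here are illustrative, not prescribed by this class:
 * <pre>
 *   PerSegmentSingleValuedFaceting f = new PerSegmentSingleValuedFaceting(
 *       searcher, docs, "category", 0, 10, 1, false,
 *       FacetParams.FACET_SORT_COUNT, null);
 *   f.setNumThreads(4);  // cap concurrent per-segment tasks
 *   NamedList&lt;Integer&gt; counts = f.getFacetCounts(executor);
 * </pre>
 */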
class PerSegmentSingleValuedFaceting {

    // input params
    SolrIndexSearcher searcher;
    DocSet docs;
    String fieldName;
    int offset;
    int limit;
    int mincount;
    boolean missing;
    String sort;
    String prefix;
    Filter baseSet;
    int nThreads;

    public PerSegmentSingleValuedFaceting(SolrIndexSearcher searcher, DocSet docs, String fieldName, int offset, int limit, int mincount, boolean missing, String sort, String prefix) {
        this.searcher = searcher;
        this.docs = docs;
        this.fieldName = fieldName;
        this.offset = offset;
        this.limit = limit;
        this.mincount = mincount;
        this.missing = missing;
        this.sort = sort;
        this.prefix = prefix;
    }

    public void setNumThreads(int threads) {
        nThreads = threads;
    }

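    /**
     * Submits one counting task per segment (at most nThreads at a time), merges
     * the per-segment counts into a single facet list in the requested sort,
     * converts labels to readable form, and optionally appends the missing count.
     */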
    NamedList<Integer> getFacetCounts(Executor executor) throws IOException {

        CompletionService<SegFacet> completionService = new ExecutorCompletionService<>(executor);

        // reuse the translation logic to go from top level set to per-segment set
        baseSet = docs.getTopFilter();

        final List<AtomicReaderContext> leaves = searcher.getTopReaderContext().leaves();
        // The list of pending tasks that aren't immediately submitted
        // TODO: Is there a completion service, or a delegating executor that can
        // limit the number of concurrent tasks submitted to a bigger executor?
        LinkedList<Callable<SegFacet>> pending = new LinkedList<>();

        int threads = nThreads <= 0 ? Integer.MAX_VALUE : nThreads;  // <= 0 means no cap on concurrent tasks

        for(final AtomicReaderContext leaf : leaves) {
            final SegFacet segFacet = new SegFacet(leaf);

            Callable<SegFacet> task = new Callable<SegFacet>() {
                @Override
                public SegFacet call() throws Exception {
                    segFacet.countTerms();
                    return segFacet;
                }
            };

            // TODO: if limiting threads, submit by largest segment first?

            if(--threads >= 0) {
                completionService.submit(task);
            }
            else {
                pending.add(task);
            }
        }


        // Merge the per-segment results with a priority queue ordered by each
        // segment's current term, so equal terms across segments can be summed.
        PriorityQueue<SegFacet> queue = new PriorityQueue<SegFacet>(leaves.size()) {
            @Override
            protected boolean lessThan(SegFacet a, SegFacet b) {
                return a.tempBR.compareTo(b.tempBR) < 0;
            }
        };


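        // Drain the completion service: as each segment finishes, fold its
        // missing-value bucket (ord 0) into missingCount and enqueue it for merging.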
        boolean hasMissingCount = false;
        int missingCount = 0;
        for(int i = 0, c = leaves.size(); i < c; i++) {
            SegFacet seg = null;

            try {
                Future<SegFacet> future = completionService.take();
                seg = future.get();
                if(!pending.isEmpty()) {
                    completionService.submit(pending.removeFirst());
                }
            }
            catch(InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
            }
            catch(ExecutionException e) {
                Throwable cause = e.getCause();
                if(cause instanceof RuntimeException) {
                    throw (RuntimeException)cause;
                }
                else {
                    throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Error in per-segment faceting on field: " + fieldName, cause);
                }
            }


            if(seg.startTermIndex < seg.endTermIndex) {
                if(seg.startTermIndex == 0) {
                    // ord 0 of the DocTermsIndex is the "missing value" bucket
                    hasMissingCount = true;
                    missingCount += seg.counts[0];
                    seg.pos = 1;
                }
                else {
                    seg.pos = seg.startTermIndex;
                }
                if(seg.pos < seg.endTermIndex) {
                    // position a TermsEnum on the segment's first counted term for the merge
                    seg.tenum = seg.si.getTermsEnum();
                    seg.tenum.seekExact(seg.pos);
                    seg.tempBR = seg.tenum.term();
                    queue.add(seg);
                }
            }
        }

        FacetCollector collector;
        if(sort.equals(FacetParams.FACET_SORT_COUNT) || sort.equals(FacetParams.FACET_SORT_COUNT_LEGACY)) {
            collector = new CountSortedFacetCollector(offset, limit, mincount);
        }
        else {
            collector = new IndexSortedFacetCollector(offset, limit, mincount);
        }

        BytesRef val = new BytesRef();

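        // Walk terms in index order: the smallest current term sits atop the queue,
        // and equal terms across segments are summed before being collected.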
        while(queue.size() > 0) {
            SegFacet seg = queue.top();

            // make a shallow copy
            val.bytes = seg.tempBR.bytes;
            val.offset = seg.tempBR.offset;
            val.length = seg.tempBR.length;

            int count = 0;

            do {
                count += seg.counts[seg.pos - seg.startTermIndex];

                // TODO: OPTIMIZATION...
                // if mincount>0 then seg.pos++ can skip ahead to the next non-zero entry.
                seg.pos++;
                if(seg.pos >= seg.endTermIndex) {
                    queue.pop();
                    seg = queue.top();
                }
                else {
                    seg.tempBR = seg.tenum.next();
                    seg = queue.updateTop();
                }
            }
            while(seg != null && val.bytesEquals(seg.tempBR));

            boolean stop = collector.collect(val, count);
            if(stop) {
                break;
            }
        }

        NamedList<Integer> res = collector.getFacetCounts();

        // convert the indexed term bytes to their readable (external) form
        FieldType ft = searcher.getSchema().getFieldType(fieldName);
        int sz = res.size();
        for(int i = 0; i < sz; i++) {
            res.setName(i, ft.indexedToReadable(res.getName(i)));
        }

        if(missing) {
            if(!hasMissingCount) {
                missingCount = SimpleFacets.getFieldMissingCount(searcher, docs, fieldName);
            }
            res.add(null, missingCount);
        }

        return res;
    }

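    // Per-segment counting state: the term-ordinal range being counted, the
    // counts array, and the merge cursor (pos/tenum/tempBR).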
    class SegFacet {

        AtomicReaderContext context;

        SegFacet(AtomicReaderContext context) {
            this.context = context;
        }
        FieldCache.DocTermsIndex si;  // ord -> term lookup for this segment's field
        int startTermIndex; // first term ord to count, inclusive
        int endTermIndex;   // last term ord to count, exclusive
        int[] counts;
        int pos; // only used when merging
        TermsEnum tenum; // only used when merging
        BytesRef tempBR = new BytesRef();

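        /**
         * Counts how many documents of the base set fall under each term ordinal
         * in [startTermIndex, endTermIndex) of this segment.
         */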
        void countTerms() throws IOException {

            si = FieldCache.DEFAULT.getTermsIndex(context.reader(), fieldName);
            // SolrCore.log.info("reader= " + reader + "  FC=" + System.identityHashCode(si));

            if(prefix != null) {
                BytesRef prefixRef = new BytesRef(prefix);
                startTermIndex = si.binarySearchLookup(prefixRef, tempBR);
                if(startTermIndex < 0) {
                    // a negative result encodes the insertion point as -(point)-1
                    startTermIndex = -startTermIndex - 1;
                }
                prefixRef.append(UnicodeUtil.BIG_TERM);
                // TODO: we could constrain the lower endpoint if we had a binarySearch method that allowed passing start/end
                endTermIndex = si.binarySearchLookup(prefixRef, tempBR);
                assert endTermIndex < 0;
                endTermIndex = -endTermIndex - 1;
            }
            else {
                startTermIndex = 0;
                endTermIndex = si.numOrd();
            }

            final int nTerms = endTermIndex - startTermIndex;
            if(nTerms > 0) {
                // count collection array only needs to be as big as the number of terms we are
                // going to collect counts for.
                final int[] lcounts = this.counts = new int[nTerms];
                DocIdSet idSet = baseSet.getDocIdSet(context, null);  // this set only includes live docs
                DocIdSetIterator iter = idSet.iterator();


                // fast path: count directly off the packed ords' backing array when one exists
                PackedInts.Reader ordReader = si.getDocToOrd();
                int doc;

                final Object arr;
                if(ordReader.hasArray()) {
                    arr = ordReader.getArray();
                }
                else {
                    arr = null;
                }

                if(arr instanceof int[]) {
                    int[] ords = (int[])arr;
                    if(prefix == null) {
                        while((doc = iter.nextDoc()) < DocIdSetIterator.NO_MORE_DOCS) {
                            lcounts[ords[doc]]++;
                        }
                    }
                    else {
                        while((doc = iter.nextDoc()) < DocIdSetIterator.NO_MORE_DOCS) {
                            int term = ords[doc];
                            int arrIdx = term - startTermIndex;
                            if(arrIdx >= 0 && arrIdx < nTerms) {
                                lcounts[arrIdx]++;
                            }
                        }
                    }
                }
                else if(arr instanceof short[]) {
                    short[] ords = (short[])arr;
                    if(prefix == null) {
                        while((doc = iter.nextDoc()) < DocIdSetIterator.NO_MORE_DOCS) {
                            lcounts[ords[doc] & 0xffff]++; // mask: treat the short ord as unsigned
                        }
                    }
                    else {
                        while((doc = iter.nextDoc()) < DocIdSetIterator.NO_MORE_DOCS) {
                            int term = ords[doc] & 0xffff;
                            int arrIdx = term - startTermIndex;
                            if(arrIdx >= 0 && arrIdx < nTerms) {
                                lcounts[arrIdx]++;
                            }
                        }
                    }
                }
                else if(arr instanceof byte[]) {
                    byte[] ords = (byte[])arr;
                    if(prefix == null) {
                        while((doc = iter.nextDoc()) < DocIdSetIterator.NO_MORE_DOCS) {
                            lcounts[ords[doc] & 0xff]++; // mask: treat the byte ord as unsigned
                        }
                    }
                    else {
                        while((doc = iter.nextDoc()) < DocIdSetIterator.NO_MORE_DOCS) {
                            int term = ords[doc] & 0xff;
                            int arrIdx = term - startTermIndex;
                            if(arrIdx >= 0 && arrIdx < nTerms) {
                                lcounts[arrIdx]++;
                            }
                        }
                    }
                }
                else {
                    if(prefix == null) {
                        // specialized version when collecting counts for all terms
                        while((doc = iter.nextDoc()) < DocIdSetIterator.NO_MORE_DOCS) {
                            lcounts[si.getOrd(doc)]++;
                        }
                    }
                    else {
                        // version that adjusts term numbers because we aren't collecting the full range
                        while((doc = iter.nextDoc()) < DocIdSetIterator.NO_MORE_DOCS) {
                            int term = si.getOrd(doc);
                            int arrIdx = term - startTermIndex;
                            if(arrIdx >= 0 && arrIdx < nTerms) {
                                lcounts[arrIdx]++;
                            }
                        }
                    }
                }
            }
        }
    }
}

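// Receives (term, count) pairs from the merge loop, in index order.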
abstract class FacetCollector {

    // return true to stop collection
    public abstract boolean collect(BytesRef term, int count);

    public abstract NamedList<Integer> getFacetCounts();
}

// This collector sorts facets by count, but still expects them to be collected
// in index order; the count>min pruning in collect() relies on that ordering.
class CountSortedFacetCollector extends FacetCollector {

    private final CharsRef spare = new CharsRef();
    final int offset;
    final int limit;
    final int maxsize;
    final BoundedTreeSet<SimpleFacets.CountPair<String, Integer>> queue;
    int min;  // the smallest value in the top 'N' values

    public CountSortedFacetCollector(int offset, int limit, int mincount) {
        this.offset = offset;
        this.limit = limit;
        maxsize = limit > 0 ? offset + limit : Integer.MAX_VALUE - 1;
        queue = new BoundedTreeSet<>(maxsize);
        min = mincount - 1;  // initial threshold: count > min admits anything with count >= mincount
    }

    @Override
    public boolean collect(BytesRef term, int count) {

        if(count > min) {
            // NOTE: we use count>min rather than count>=min as an optimization because we are
            // going in index order, so we already know that the keys are ordered.  This can be
            // very important if a lot of the counts are repeated (like zero counts would be).
            UnicodeUtil.UTF8toUTF16(term, spare);
            queue.add(new SimpleFacets.CountPair<>(spare.toString(), count));
            if(queue.size() >= maxsize) {
                min = queue.last().val;
            }
        }
        return false;
    }

    @Override
    public NamedList<Integer> getFacetCounts() {

        NamedList<Integer> res = new NamedList<>();
        int off = offset;
        int lim = limit >= 0 ? limit : Integer.MAX_VALUE;
        // now select the right page from the results
        for(SimpleFacets.CountPair<String, Integer> p : queue) {
            if(--off >= 0) {
                continue;
            }
            if(--lim < 0) {
                break;
            }
            res.add(p.key, p.val);
        }
        return res;
    }
}

// This collector expects facets to be collected in index order
class IndexSortedFacetCollector extends FacetCollector {

    private final CharsRef spare = new CharsRef();
    int offset;
    int limit;
    final int mincount;
    final NamedList<Integer> res = new NamedList<>();

    public IndexSortedFacetCollector(int offset, int limit, int mincount) {
        this.offset = offset;
        this.limit = limit > 0 ? limit : Integer.MAX_VALUE;
        this.mincount = mincount;
    }

    @Override
    public boolean collect(BytesRef term, int count) {

        if(count < mincount) {
            return false;
        }

        if(offset > 0) {
            offset--;
            return false;
        }

        if(limit > 0) {
            UnicodeUtil.UTF8toUTF16(term, spare);
            res.add(spare.toString(), count);
            limit--;
        }

        return limit <= 0;
    }

    @Override
    public NamedList<Integer> getFacetCounts() {
        return res;
    }
}
